{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SKD7jSvIn1T9"
   },
   "source": [
    "# Importing Modules\n",
    "\n",
    "The necessary modules are : os, opencv, numpy, tqdm, matplotlib, keras and sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "nNtz_YYVnzcZ",
    "outputId": "21000b0d-93c5-4dcf-96f3-4bedc5d90a3c"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, BatchNormalization, Activation, add\n",
    "from keras.models import Model, model_from_json\n",
    "from keras.optimizers import Adam\n",
    "from keras.layers.advanced_activations import ELU, LeakyReLU\n",
    "from keras.utils.vis_utils import plot_model\n",
    "from keras import backend as K \n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fb1Dw_RroBkh"
   },
   "source": [
    "# Downloading Data\n",
    "\n",
    "Here, we download the image data and the corresponding ground truth segmentation masks.\n",
    "\n",
    "In the paper we had not tested on the ISIC-17 Dataset, which will be used here for the demo."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "u5MQHA30oWEl"
   },
   "source": [
    "## Downloading the Images\n",
    "\n",
    "Please note that the links may be different when you are running this notebook. Please update the link accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 309
    },
    "colab_type": "code",
    "id": "XrMkfbme0nG_",
    "outputId": "d3a31f2d-5e1d-4fc4-8796-79a135c287a5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2019-05-03 04:40:50--  https://challenge.kitware.com/api/v1/item/584ad08dcad3a51cc66c8e12/download\n",
      "Resolving challenge.kitware.com (challenge.kitware.com)... 54.208.189.152\n",
      "Connecting to challenge.kitware.com (challenge.kitware.com)|54.208.189.152|:443... connected.\n",
      "HTTP request sent, awaiting response... 303 See Other\n",
      "Location: https://s3.amazonaws.com/covalic-prod-assetstore/34/2a/342a8071e8714a32a697b33758d68154?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=3600&X-Amz-Credential=AKIAITHBL3CJMECU3C4A%2F20190503%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-SignedHeaders=host&X-Amz-Date=20190503T044051Z&X-Amz-Signature=2453f040a58dd267b29d3c27aad7e7caec4a866da5ff993a7e98aee9e30eec6b [following]\n",
      "--2019-05-03 04:40:51--  https://s3.amazonaws.com/covalic-prod-assetstore/34/2a/342a8071e8714a32a697b33758d68154?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=3600&X-Amz-Credential=AKIAITHBL3CJMECU3C4A%2F20190503%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-SignedHeaders=host&X-Amz-Date=20190503T044051Z&X-Amz-Signature=2453f040a58dd267b29d3c27aad7e7caec4a866da5ff993a7e98aee9e30eec6b\n",
      "Resolving s3.amazonaws.com (s3.amazonaws.com)... 54.231.40.130\n",
      "Connecting to s3.amazonaws.com (s3.amazonaws.com)|54.231.40.130|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 6229496702 (5.8G) [application/zip]\n",
      "Saving to: ‘ISIC_2017_images.zip’\n",
      "\n",
      "ISIC_2017_images.zi 100%[===================>]   5.80G  15.6MB/s    in 6m 29s  \n",
      "\n",
      "2019-05-03 04:47:21 (15.3 MB/s) - ‘ISIC_2017_images.zip’ saved [6229496702/6229496702]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "! wget \"https://challenge.kitware.com/api/v1/item/584ad08dcad3a51cc66c8e12/download\" -O ISIC_2017_images.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UCQn7FSHoiTk"
   },
   "source": [
    "## Downloading the Ground Truth Segmentation\n",
    "\n",
    "Please note that the links may be different when you are running this notebook. Please update the link accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 309
    },
    "colab_type": "code",
    "id": "enE3vtZr1G5X",
    "outputId": "93e2377b-3359-410a-d9b0-d3356d1d422c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2019-05-03 04:47:24--  https://challenge.kitware.com/api/v1/item/584997d0cad3a51cc66c8e00/download\n",
      "Resolving challenge.kitware.com (challenge.kitware.com)... 54.208.189.152\n",
      "Connecting to challenge.kitware.com (challenge.kitware.com)|54.208.189.152|:443... connected.\n",
      "HTTP request sent, awaiting response... 303 See Other\n",
      "Location: https://s3.amazonaws.com/covalic-prod-assetstore/79/d1/79d1d1af650248df95b4b51497650ba0?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=3600&X-Amz-Credential=AKIAITHBL3CJMECU3C4A%2F20190503%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-SignedHeaders=host&X-Amz-Date=20190503T044725Z&X-Amz-Signature=d0ec8e6d436fbdc278ac82b61954f91e76112bb66902122442be95a3947fffc8 [following]\n",
      "--2019-05-03 04:47:25--  https://s3.amazonaws.com/covalic-prod-assetstore/79/d1/79d1d1af650248df95b4b51497650ba0?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=3600&X-Amz-Credential=AKIAITHBL3CJMECU3C4A%2F20190503%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-SignedHeaders=host&X-Amz-Date=20190503T044725Z&X-Amz-Signature=d0ec8e6d436fbdc278ac82b61954f91e76112bb66902122442be95a3947fffc8\n",
      "Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.162.205\n",
      "Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.162.205|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 9321981 (8.9M) [application/zip]\n",
      "Saving to: ‘ISIC_2017_masks.zip’\n",
      "\n",
      "ISIC_2017_masks.zip 100%[===================>]   8.89M  6.40MB/s    in 1.4s    \n",
      "\n",
      "2019-05-03 04:47:27 (6.40 MB/s) - ‘ISIC_2017_masks.zip’ saved [9321981/9321981]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "! wget \"https://challenge.kitware.com/api/v1/item/584997d0cad3a51cc66c8e00/download\" -O ISIC_2017_masks.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2YFLs2aFoxgN"
   },
   "source": [
    "## Extracting the Zip Files\n",
    "\n",
    "Next the downloaded zip files are unzipped and the data is extracted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pCy-G3jH2E9D"
   },
   "outputs": [],
   "source": [
    "! unzip ISIC_2017_images.zip -d ISIC_2017_images\n",
    "! unzip ISIC_2017_masks.zip -d ISIC_2017_masks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xLkeS7Lxo_Po"
   },
   "source": [
    "# Constructing Training and Test Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gHo6ymwFpJzY"
   },
   "source": [
    "## Loading the Images\n",
    "\n",
    "We first load all the images and the corresponding segmentation masks. \n",
    "\n",
    "They are stored in two lists X, Y and respectively\n",
    "\n",
    "Moreover, the images are resized to 256x192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "YbeQ1aytHTm3",
    "outputId": "d43c1e52-d889-4a4f-96aa-5ef4d20b649f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 20/4001 [00:00<00:20, 191.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4001\n",
      "2000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 928/4001 [00:13<01:03, 48.45it/s]"
     ]
    }
   ],
   "source": [
    "img_files = next(os.walk('ISIC_2017_images/ISIC-2017_Training_Data'))[2]\n",
    "msk_files = next(os.walk('ISIC_2017_masks/ISIC-2017_Training_Part1_GroundTruth'))[2]\n",
    "\n",
    "img_files.sort()\n",
    "msk_files.sort()\n",
    "\n",
    "print(len(img_files))\n",
    "print(len(msk_files))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "X = []\n",
    "Y = []\n",
    "\n",
    "for img_fl in tqdm(img_files):    \n",
    "    if(img_fl.split('.')[-1]=='jpg'):\n",
    "                \n",
    "        \n",
    "        img = cv2.imread('ISIC_2017_images/ISIC-2017_Training_Data/{}'.format(img_fl), cv2.IMREAD_COLOR)\n",
    "        resized_img = cv2.resize(img,(256, 192), interpolation = cv2.INTER_CUBIC)\n",
    "        \n",
    "        X.append(resized_img)\n",
    "        \n",
    "        msk = cv2.imread('ISIC_2017_masks/ISIC-2017_Training_Part1_GroundTruth/{}'.format(img_fl.split('.')[0]+'_segmentation.png'), cv2.IMREAD_GRAYSCALE)\n",
    "        resized_msk = cv2.resize(msk,(256, 192), interpolation = cv2.INTER_CUBIC)\n",
    "        \n",
    "        Y.append(resized_msk)\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jYgEz9HQpgsR"
   },
   "source": [
    "## Train-Test Split\n",
    "\n",
    "The X, Y lists are converted to numpy arrays for convenience. \n",
    "Furthermore, the images are divided by 255 to bring down the pixel values to [0...1] range. On the other hand the segmentations masks are converted to binary (0 or 1) values.\n",
    "\n",
    "Using Sklearn *train_test_split* we split the data randomly into 80% training and 20% testing data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "id": "CxFIZv3715Dt",
    "outputId": "5a93db60-ed50-4122-8aa0-b26e42a601ca"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000\n",
      "2000\n",
      "(1600, 192, 256, 3)\n",
      "(1600, 192, 256, 1)\n",
      "{0.12549019607843137, 0.24705882352941178, 0.5607843137254902, 0.5764705882352941, 0.5058823529411764, 0.5490196078431373, 0.5529411764705883, 0.5568627450980392, 0.5686274509803921, 0.5098039215686274, 0.06274509803921569, 0.06666666666666667, 0.15294117647058825, 0.18823529411764706, 0.21568627450980393, 0.12941176470588237, 0.13333333333333333, 0.07058823529411765, 0.19215686274509805, 0.19607843137254902, 0.25098039215686274, 0.28627450980392155, 0.3137254901960784, 0.3058823529411765, 0.34901960784313724, 0.3411764705882353, 0.07450980392156863, 0.3686274509803922, 0.3764705882352941, 0.403921568627451, 0.4117647058823529, 0.4392156862745098, 0.43137254901960786, 0.4745098039215686, 0.4666666666666667, 0.2627450980392157, 0.3215686274509804, 0.26666666666666666, 0.27058823529411763, 0.396078431372549, 0.5019607843137255, 0.3254901960784314, 0.3843137254901961, 0.38823529411764707, 0.5372549019607843, 0.49411764705882355, 0.5725490196078431, 0.5647058823529412, 0.0784313725490196, 0.592156862745098, 0.5843137254901961, 0.6, 0.6078431372549019, 0.08235294117647059, 0.611764705882353, 0.6196078431372549, 0.6274509803921569, 0.6352941176470588, 0.1411764705882353, 0.6549019607843137, 0.6470588235294118, 0.16862745098039217, 0.20392156862745098, 0.7176470588235294, 0.23137254901960785, 0.14901960784313725, 0.21176470588235294, 0.08627450980392157, 0.1450980392156863, 0.20784313725490197, 0.25882352941176473, 0.29411764705882354, 0.30196078431372547, 0.32941176470588235, 0.09019607843137255, 0.39215686274509803, 0.42745098039215684, 0.4196078431372549, 0.4470588235294118, 0.4549019607843137, 0.09411764705882353, 0.2784313725490196, 0.5333333333333333, 0.5254901960784314, 0.5176470588235295, 0.5294117647058824, 0.5215686274509804, 0.5450980392156862, 0.5411764705882353, 0.4627450980392157, 0.09803921568627451, 0.1568627450980392, 0.1843137254901961, 0.2196078431372549, 0.10196078431372549, 0.16470588235294117, 0.22745098039215686, 0.1607843137254902, 0.2549019607843137, 0.2235294117647059, 0.27450980392156865, 0.2823529411764706, 0.30980392156862746, 0.3176470588235294, 0.34509803921568627, 0.33725490196078434, 0.10588235294117647, 0.37254901960784315, 0.3803921568627451, 0.40784313725490196, 0.4, 0.43529411764705883, 0.44313725490196076, 0.47058823529411764, 0.10980392156862745, 0.2901960784313726, 0.5137254901960784, 0.3568627450980392, 0.36470588235294116, 0.47843137254901963, 0.2980392156862745, 0.4823529411764706, 0.35294117647058826, 0.49019607843137253, 0.4980392156862745, 0.5882352941176471, 0.5803921568627451, 0.6039215686274509, 0.596078431372549, 0.11372549019607843, 0.6235294117647059, 0.615686274509804, 0.6392156862745098, 0.13725490196078433, 0.6588235294117647, 0.17254901960784313, 0.6784313725490196, 0.7019607843137254, 0.6862745098039216, 0.2, 0.23529411764705882, 0.054901960784313725, 0.17647058823529413, 0.11764705882352941, 0.23921568627450981, 0.24313725490196078, 0.1803921568627451, 0.3333333333333333, 0.3607843137254902, 0.12156862745098039, 0.058823529411764705, 0.4235294117647059, 0.41568627450980394, 0.4588235294117647, 0.45098039215686275, 0.48627450980392156}\n",
      "{0.0, 1.0}\n"
     ]
    }
   ],
   "source": [
    "print(len(X))\n",
    "print(len(Y))\n",
    "\n",
    "X = np.array(X)\n",
    "Y = np.array(Y)\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=3)\n",
    "\n",
    "Y_train = Y_train.reshape((Y_train.shape[0],Y_train.shape[1],Y_train.shape[2],1))\n",
    "Y_test = Y_test.reshape((Y_test.shape[0],Y_test.shape[1],Y_test.shape[2],1))\n",
    "\n",
    "X_train = X_train / 255\n",
    "X_test = X_test / 255\n",
    "Y_train = Y_train / 255\n",
    "Y_test = Y_test / 255\n",
    "\n",
    "Y_train = np.round(Y_train,0)\t\n",
    "Y_test = np.round(Y_test,0)\t\n",
    "\n",
    "print(X_train.shape)\n",
    "print(Y_train.shape)\n",
    "print(X_test.shape)\n",
    "print(Y_test.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C-Ajp2QVrMti"
   },
   "source": [
    "# MultiResUNet Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "21cbmiojrYrU"
   },
   "source": [
    "## Model Definition\n",
    "\n",
    "The MultiResUNet model as described in the [paper](https://arxiv.org/abs/1902.04049) can be found  [here](https://github.com/nibtehaz/MultiResUNet/blob/master/MultiResUNet.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2nX7I1Wf_zEy"
   },
   "outputs": [],
   "source": [
    "def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):\n",
    "    '''\n",
    "    2D Convolutional layers\n",
    "    \n",
    "    Arguments:\n",
    "        x {keras layer} -- input layer \n",
    "        filters {int} -- number of filters\n",
    "        num_row {int} -- number of rows in filters\n",
    "        num_col {int} -- number of columns in filters\n",
    "    \n",
    "    Keyword Arguments:\n",
    "        padding {str} -- mode of padding (default: {'same'})\n",
    "        strides {tuple} -- stride of convolution operation (default: {(1, 1)})\n",
    "        activation {str} -- activation function (default: {'relu'})\n",
    "        name {str} -- name of the layer (default: {None})\n",
    "    \n",
    "    Returns:\n",
    "        [keras layer] -- [output layer]\n",
    "    '''\n",
    "\n",
    "    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)\n",
    "    x = BatchNormalization(axis=3, scale=False)(x)\n",
    "\n",
    "    if(activation == None):\n",
    "        return x\n",
    "\n",
    "    x = Activation(activation, name=name)(x)\n",
    "\n",
    "    return x\n",
    "\n",
    "\n",
    "def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):\n",
    "    '''\n",
    "    2D Transposed Convolutional layers\n",
    "    \n",
    "    Arguments:\n",
    "        x {keras layer} -- input layer \n",
    "        filters {int} -- number of filters\n",
    "        num_row {int} -- number of rows in filters\n",
    "        num_col {int} -- number of columns in filters\n",
    "    \n",
    "    Keyword Arguments:\n",
    "        padding {str} -- mode of padding (default: {'same'})\n",
    "        strides {tuple} -- stride of convolution operation (default: {(2, 2)})\n",
    "        name {str} -- name of the layer (default: {None})\n",
    "    \n",
    "    Returns:\n",
    "        [keras layer] -- [output layer]\n",
    "    '''\n",
    "\n",
    "    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)\n",
    "    x = BatchNormalization(axis=3, scale=False)(x)\n",
    "    \n",
    "    return x\n",
    "\n",
    "\n",
    "def MultiResBlock(U, inp, alpha = 1.67):\n",
    "    '''\n",
    "    MultiRes Block\n",
    "    \n",
    "    Arguments:\n",
    "        U {int} -- Number of filters in a corrsponding UNet stage\n",
    "        inp {keras layer} -- input layer \n",
    "    \n",
    "    Returns:\n",
    "        [keras layer] -- [output layer]\n",
    "    '''\n",
    "\n",
    "    W = alpha * U\n",
    "\n",
    "    shortcut = inp\n",
    "\n",
    "    shortcut = conv2d_bn(shortcut, int(W*0.167) + int(W*0.333) +\n",
    "                         int(W*0.5), 1, 1, activation=None, padding='same')\n",
    "\n",
    "    conv3x3 = conv2d_bn(inp, int(W*0.167), 3, 3,\n",
    "                        activation='relu', padding='same')\n",
    "\n",
    "    conv5x5 = conv2d_bn(conv3x3, int(W*0.333), 3, 3,\n",
    "                        activation='relu', padding='same')\n",
    "\n",
    "    conv7x7 = conv2d_bn(conv5x5, int(W*0.5), 3, 3,\n",
    "                        activation='relu', padding='same')\n",
    "\n",
    "    out = concatenate([conv3x3, conv5x5, conv7x7], axis=3)\n",
    "    out = BatchNormalization(axis=3)(out)\n",
    "\n",
    "    out = add([shortcut, out])\n",
    "    out = Activation('relu')(out)\n",
    "    out = BatchNormalization(axis=3)(out)\n",
    "\n",
    "    return out\n",
    "\n",
    "\n",
    "def ResPath(filters, length, inp):\n",
    "    '''\n",
    "    ResPath\n",
    "    \n",
    "    Arguments:\n",
    "        filters {int} -- [description]\n",
    "        length {int} -- length of ResPath\n",
    "        inp {keras layer} -- input layer \n",
    "    \n",
    "    Returns:\n",
    "        [keras layer] -- [output layer]\n",
    "    '''\n",
    "\n",
    "\n",
    "    shortcut = inp\n",
    "    shortcut = conv2d_bn(shortcut, filters, 1, 1,\n",
    "                         activation=None, padding='same')\n",
    "\n",
    "    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')\n",
    "\n",
    "    out = add([shortcut, out])\n",
    "    out = Activation('relu')(out)\n",
    "    out = BatchNormalization(axis=3)(out)\n",
    "\n",
    "    for i in range(length-1):\n",
    "\n",
    "        shortcut = out\n",
    "        shortcut = conv2d_bn(shortcut, filters, 1, 1,\n",
    "                             activation=None, padding='same')\n",
    "\n",
    "        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')\n",
    "\n",
    "        out = add([shortcut, out])\n",
    "        out = Activation('relu')(out)\n",
    "        out = BatchNormalization(axis=3)(out)\n",
    "\n",
    "    return out\n",
    "\n",
    "\n",
    "def MultiResUnet(height, width, n_channels):\n",
    "    '''\n",
    "    MultiResUNet\n",
    "    \n",
    "    Arguments:\n",
    "        height {int} -- height of image \n",
    "        width {int} -- width of image \n",
    "        n_channels {int} -- number of channels in image\n",
    "    \n",
    "    Returns:\n",
    "        [keras model] -- MultiResUNet model\n",
    "    '''\n",
    "\n",
    "\n",
    "    inputs = Input((height, width, n_channels))\n",
    "\n",
    "    mresblock1 = MultiResBlock(32, inputs)\n",
    "    pool1 = MaxPooling2D(pool_size=(2, 2))(mresblock1)\n",
    "    mresblock1 = ResPath(32, 4, mresblock1)\n",
    "\n",
    "    mresblock2 = MultiResBlock(32*2, pool1)\n",
    "    pool2 = MaxPooling2D(pool_size=(2, 2))(mresblock2)\n",
    "    mresblock2 = ResPath(32*2, 3, mresblock2)\n",
    "\n",
    "    mresblock3 = MultiResBlock(32*4, pool2)\n",
    "    pool3 = MaxPooling2D(pool_size=(2, 2))(mresblock3)\n",
    "    mresblock3 = ResPath(32*4, 2, mresblock3)\n",
    "\n",
    "    mresblock4 = MultiResBlock(32*8, pool3)\n",
    "    pool4 = MaxPooling2D(pool_size=(2, 2))(mresblock4)\n",
    "    mresblock4 = ResPath(32*8, 1, mresblock4)\n",
    "\n",
    "    mresblock5 = MultiResBlock(32*16, pool4)\n",
    "\n",
    "    up6 = concatenate([Conv2DTranspose(\n",
    "        32*8, (2, 2), strides=(2, 2), padding='same')(mresblock5), mresblock4], axis=3)\n",
    "    mresblock6 = MultiResBlock(32*8, up6)\n",
    "\n",
    "    up7 = concatenate([Conv2DTranspose(\n",
    "        32*4, (2, 2), strides=(2, 2), padding='same')(mresblock6), mresblock3], axis=3)\n",
    "    mresblock7 = MultiResBlock(32*4, up7)\n",
    "\n",
    "    up8 = concatenate([Conv2DTranspose(\n",
    "        32*2, (2, 2), strides=(2, 2), padding='same')(mresblock7), mresblock2], axis=3)\n",
    "    mresblock8 = MultiResBlock(32*2, up8)\n",
    "\n",
    "    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(\n",
    "        2, 2), padding='same')(mresblock8), mresblock1], axis=3)\n",
    "    mresblock9 = MultiResBlock(32, up9)\n",
    "\n",
    "    conv10 = conv2d_bn(mresblock9, 1, 1, 1, activation='sigmoid')\n",
    "    \n",
    "    model = Model(inputs=[inputs], outputs=[conv10])\n",
    "\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2frZlpmFsv1f"
   },
   "source": [
    "## Auxiliary Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "degBGBWYsyNG"
   },
   "source": [
    "### Custom Metrics\n",
    "\n",
    "Since Keras does not have build-in support for computing Dice Coefficient or Jaccard Index (at the time of writing), the following functions are declared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xq8qfLqDA6q2"
   },
   "outputs": [],
   "source": [
    "def dice_coef(y_true, y_pred):\n",
    "    smooth = 0.0\n",
    "    y_true_f = K.flatten(y_true)\n",
    "    y_pred_f = K.flatten(y_pred)\n",
    "    intersection = K.sum(y_true_f * y_pred_f)\n",
    "    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)\n",
    "\n",
    "def jacard(y_true, y_pred):\n",
    "\n",
    "    y_true_f = K.flatten(y_true)\n",
    "    y_pred_f = K.flatten(y_pred)\n",
    "    intersection = K.sum ( y_true_f * y_pred_f)\n",
    "    union = K.sum ( y_true_f + y_pred_f - y_true_f * y_pred_f)\n",
    "\n",
    "    return intersection/union"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VtvyeBXy8Mk3"
   },
   "source": [
    "### Saving Model \n",
    "\n",
    "Function to save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GCfX9MUYALar"
   },
   "outputs": [],
   "source": [
    "def saveModel(model):\n",
    "\n",
    "    model_json = model.to_json()\n",
    "\n",
    "    try:\n",
    "        os.makedirs('models')\n",
    "    except:\n",
    "        pass\n",
    "    \n",
    "    fp = open('models/modelP.json','w')\n",
    "    fp.write(model_json)\n",
    "    model.save_weights('models/modelW.h5')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AH3znjx-8vDq"
   },
   "source": [
    "### Evaluate the Model\n",
    "\n",
    "We evaluate the model on test data (X_test, Y_test). \n",
    "\n",
    "We compute the values of Jaccard Index and Dice Coeficient, and save the predicted segmentation of first 10 images. The best model is also saved\n",
    "\n",
    "(This could have been done using keras call-backs as well)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Tkit_YYvBQ7V"
   },
   "outputs": [],
   "source": [
    "def evaluateModel(model, X_test, Y_test, batchSize):\n",
    "    \n",
    "    try:\n",
    "        os.makedirs('results')\n",
    "    except:\n",
    "        pass \n",
    "    \n",
    "\n",
    "    yp = model.predict(x=X_test, batch_size=batchSize, verbose=1)\n",
    "\n",
    "    yp = np.round(yp,0)\n",
    "\n",
    "    for i in range(10):\n",
    "\n",
    "        plt.figure(figsize=(20,10))\n",
    "        plt.subplot(1,3,1)\n",
    "        plt.imshow(X_test[i])\n",
    "        plt.title('Input')\n",
    "        plt.subplot(1,3,2)\n",
    "        plt.imshow(Y_test[i].reshape(Y_test[i].shape[0],Y_test[i].shape[1]))\n",
    "        plt.title('Ground Truth')\n",
    "        plt.subplot(1,3,3)\n",
    "        plt.imshow(yp[i].reshape(yp[i].shape[0],yp[i].shape[1]))\n",
    "        plt.title('Prediction')\n",
    "\n",
    "        intersection = yp[i].ravel() * Y_test[i].ravel()\n",
    "        union = yp[i].ravel() + Y_test[i].ravel() - intersection\n",
    "\n",
    "        jacard = (np.sum(intersection)/np.sum(union))  \n",
    "        plt.suptitle('Jacard Index'+ str(np.sum(intersection)) +'/'+ str(np.sum(union)) +'='+str(jacard))\n",
    "\n",
    "        plt.savefig('results/'+str(i)+'.png',format='png')\n",
    "        plt.close()\n",
    "\n",
    "\n",
    "    jacard = 0\n",
    "    dice = 0\n",
    "    \n",
    "    \n",
    "    for i in range(len(Y_test)):\n",
    "        yp_2 = yp[i].ravel()\n",
    "        y2 = Y_test[i].ravel()\n",
    "        \n",
    "        intersection = yp_2 * y2\n",
    "        union = yp_2 + y2 - intersection\n",
    "\n",
    "        jacard += (np.sum(intersection)/np.sum(union))  \n",
    "\n",
    "        dice += (2. * np.sum(intersection) ) / (np.sum(yp_2) + np.sum(y2))\n",
    "\n",
    "    \n",
    "    jacard /= len(Y_test)\n",
    "    dice /= len(Y_test)\n",
    "    \n",
    "\n",
    "\n",
    "    print('Jacard Index : '+str(jacard))\n",
    "    print('Dice Coefficient : '+str(dice))\n",
    "    \n",
    "\n",
    "    fp = open('models/log.txt','a')\n",
    "    fp.write(str(jacard)+'\\n')\n",
    "    fp.close()\n",
    "\n",
    "    fp = open('models/best.txt','r')\n",
    "    best = fp.read()\n",
    "    fp.close()\n",
    "\n",
    "    if(jacard>float(best)):\n",
    "        print('***********************************************')\n",
    "        print('Jacard Index improved from '+str(best)+' to '+str(jacard))\n",
    "        print('***********************************************')\n",
    "        fp = open('models/best.txt','w')\n",
    "        fp.write(str(jacard))\n",
    "        fp.close()\n",
    "\n",
    "        saveModel(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "80IwnKtM9NHY"
   },
   "source": [
    "### Training the Model\n",
    "\n",
    "The model is trained and evaluated after each epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wlOOh0-nA05L"
   },
   "outputs": [],
   "source": [
    "def trainStep(model, X_train, Y_train, X_test, Y_test, epochs, batchSize):\n",
    "\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        print('Epoch : {}'.format(epoch+1))\n",
    "        model.fit(x=X_train, y=Y_train, batch_size=batchSize, epochs=1, verbose=1)     \n",
    "\n",
    "        evaluateModel(model,X_test, Y_test,batchSize)\n",
    "\n",
    "    return model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4wJ5V9AJ9Ygu"
   },
   "source": [
    "## Define Model, Train and Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1105
    },
    "colab_type": "code",
    "id": "GJqeuZPhDZSK",
    "outputId": "b1344374-f69e-4805-a22d-b8152f07b754"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 1\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 95s 59ms/step - loss: 0.2404 - dice_coef: 0.5408 - jacard: 0.3714 - acc: 0.9544\n",
      "400/400 [==============================] - 7s 17ms/step\n",
      "Jacard Index : 0.739111843270437\n",
      "Dice Coefficient : 0.8201970248614221\n",
      "Epoch : 2\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 95s 59ms/step - loss: 0.2363 - dice_coef: 0.5441 - jacard: 0.3742 - acc: 0.9565\n",
      "400/400 [==============================] - 7s 17ms/step\n",
      "Jacard Index : 0.7576494724352857\n",
      "Dice Coefficient : 0.849239599793822\n",
      "Epoch : 3\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 95s 59ms/step - loss: 0.2344 - dice_coef: 0.5441 - jacard: 0.3741 - acc: 0.9549\n",
      "400/400 [==============================] - 7s 17ms/step\n",
      "Jacard Index : 0.73477063233349\n",
      "Dice Coefficient : 0.8258158003605212\n",
      "Epoch : 4\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 95s 59ms/step - loss: 0.2323 - dice_coef: 0.5442 - jacard: 0.3742 - acc: 0.9549\n",
      "400/400 [==============================] - 7s 17ms/step\n",
      "Jacard Index : 0.7418988453931786\n",
      "Dice Coefficient : 0.8323562169488263\n",
      "Epoch : 5\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2313 - dice_coef: 0.5447 - jacard: 0.3746 - acc: 0.9536\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7290094940776954\n",
      "Dice Coefficient : 0.8232707436328333\n",
      "Epoch : 6\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2330 - dice_coef: 0.5393 - jacard: 0.3697 - acc: 0.9482\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7179003495310056\n",
      "Dice Coefficient : 0.8171223533396312\n",
      "Epoch : 7\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2282 - dice_coef: 0.5453 - jacard: 0.3751 - acc: 0.9536\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7455700229896323\n",
      "Dice Coefficient : 0.8377428294163479\n",
      "Epoch : 8\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2282 - dice_coef: 0.5437 - jacard: 0.3737 - acc: 0.9465\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7424750221665336\n",
      "Dice Coefficient : 0.8310102778816032\n",
      "Epoch : 9\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2244 - dice_coef: 0.5479 - jacard: 0.3775 - acc: 0.9487\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7778693739501523\n",
      "Dice Coefficient : 0.8592311207271003\n",
      "Epoch : 10\n",
      "Epoch 1/1\n",
      "1600/1600 [==============================] - 92s 57ms/step - loss: 0.2257 - dice_coef: 0.5449 - jacard: 0.3747 - acc: 0.9493\n",
      "400/400 [==============================] - 7s 16ms/step\n",
      "Jacard Index : 0.7845456434833971\n",
      "Dice Coefficient : 0.86388530151125\n",
      "***********************************************\n",
      "Jacard Index improved from 0.7816859684059847 to 0.7845456434833971\n",
      "***********************************************\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.engine.training.Model at 0x7f67ca2b09e8>"
      ]
     },
     "execution_count": 21,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = MultiResUnet(height=192, width=256, n_channels=3)\n",
    "\n",
    "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[dice_coef, jacard, 'accuracy'])\n",
    "\n",
    "saveModel(model)\n",
    "\n",
    "fp = open('models/log.txt','w')\n",
    "fp.close()\n",
    "fp = open('models/best.txt','w')\n",
    "fp.write('-1.0')\n",
    "fp.close()\n",
    "    \n",
    "trainStep(model, X_train, Y_train, X_test, Y_test, epochs=10, batchSize=10)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "ISIC.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
